{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ef92c15d",
   "metadata": {},
   "source": [
    "## Metrics\n",
    "\n",
    "汇总常见二分类任务的指标，例如：AUC、ROC 曲线、ACC、敏感性、特异性、精确度、召回率、PPV、NPV、F1。\n",
    "\n",
    "具体的介绍，可以参考一下：https://blog.csdn.net/sunflower_sara/article/details/81214897"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18d9b291",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from datetime import datetime\n",
    "from onekey_algo import get_param_in_cwd\n",
    "\n",
    "# Create the output folders for figures and result tables if they do not exist yet.\n",
    "for output_dir in ('img', 'results'):\n",
    "    os.makedirs(output_dir, exist_ok=True)\n",
    "# Location of the model logs. Leave this unchanged if the default save location was not\n",
    "# modified and the models were trained today.\n",
    "model_root = get_param_in_cwd('model_root', 'models')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a45a469",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy  as np\n",
    "from onekey_algo.custom.components import metrics\n",
    "from onekey_algo.custom.components.comp1 import draw_roc\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Evaluate every model found under model_root on its train / validation logs.\n",
    "# Each log is the tab-separated file viz/BST_{TRAIN,VAL}_RESULTS.txt inside a model directory;\n",
    "# point model_root at the Onekey output folder that contains those logs.\n",
    "POSITIVE_LABEL = 1\n",
    "LOG_COLUMNS = ['fname', 'pred_score', 'pred_label', 'gt']\n",
    "\n",
    "# Only sub-directories can hold a viz/ log folder, so skip stray files (and *.pkl entries, as before).\n",
    "# Sorting keeps the row order of the metrics table stable across runs.\n",
    "model_names = sorted(name for name in os.listdir(model_root)\n",
    "                     if not name.endswith('.pkl') and os.path.isdir(os.path.join(model_root, name)))\n",
    "\n",
    "metric_results = []\n",
    "all_predict_scores = []\n",
    "all_gts = []\n",
    "for model in model_names:\n",
    "    all_pred = []\n",
    "    all_gt = []\n",
    "    for subset, cohort in (('Train', 'TRAIN'), ('Test', 'VAL')):\n",
    "        log_path = os.path.join(model_root, model, f'viz/BST_{cohort}_RESULTS.txt')\n",
    "        val_log = pd.read_csv(log_path, names=LOG_COLUMNS, sep='\\t')\n",
    "        scores = np.asarray(val_log['pred_score'])\n",
    "        is_positive = np.asarray(val_log['pred_label']) == POSITIVE_LABEL\n",
    "        # pred_score is treated as the score of the predicted label (TODO confirm against the log writer):\n",
    "        # it is kept for the predicted class and 1 - score is assigned to the other class.\n",
    "        prob_positive = np.where(is_positive, scores, 1 - scores)\n",
    "        prob_negative = np.where(is_positive, 1 - scores, scores)\n",
    "        gt = [1 if gt_ == POSITIVE_LABEL else 0 for gt_ in np.asarray(val_log['gt'])]\n",
    "        acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres = metrics.analysis_pred_binary(\n",
    "            gt, list(prob_positive), use_youden=True)\n",
    "        ci = f'{ci[0]:.4f}-{ci[1]:.4f}'\n",
    "        metric_results.append([model, acc, auc, ci, tpr, tnr, ppv, npv, precision, recall, f1, thres, subset])\n",
    "        # One (n_samples, 2) array per cohort, columns ordered [label-0, label-1].\n",
    "        all_pred.append(np.column_stack([prob_negative, prob_positive]))\n",
    "        all_gt.append(gt)\n",
    "    all_predict_scores.extend(all_pred)\n",
    "    all_gts.extend(all_gt)\n",
    "    draw_roc(all_gt, all_pred, labels=['Train', 'Val'], title=f'Model: {model}')\n",
    "    plt.savefig(f'img/DTL_{model}_roc.svg', bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "# Named metrics_df rather than metrics so the imported metrics module is not shadowed;\n",
    "# shadowing it made this cell fail when it was run a second time.\n",
    "metrics_df = pd.DataFrame(metric_results, \n",
    "             columns=['ModelName', 'Acc', 'AUC', '95% CI', 'Sensitivity', 'Specificity', 'PPV', 'NPV', \n",
    "                      'Precision', 'Recall', 'F1', 'Threshold', 'Cohort'])\n",
    "metrics_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a45a539",
   "metadata": {},
   "source": [
    "# 保存预测结果\n",
    "\n",
    "将深度学习的预测结果保存为与组学预测结果相同的格式，方便后续进行汇总。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06cc4c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export each model's predictions as per-class scores (label-0, label-1),\n",
    "# one CSV per cohort under results/.\n",
    "log_columns = ['fname', 'pred_score', 'pred_label', 'gt']\n",
    "for model in sorted(os.listdir(model_root)):\n",
    "    model_dir = os.path.join(model_root, model)\n",
    "    if model.endswith('.pkl') or not os.path.isdir(model_dir):\n",
    "        continue  # not a model log directory\n",
    "\n",
    "    # The cohort goes into the file name through an f-string rather than str.format on the joined path,\n",
    "    # so braces in model_root or in the model name cannot break the formatting.\n",
    "    cohort_logs = []\n",
    "    for cohort, group in (('TRAIN', 'train'), ('VAL', 'test')):\n",
    "        log_path = os.path.join(model_dir, f'viz/BST_{cohort}_RESULTS.txt')\n",
    "        cohort_log = pd.read_csv(log_path, names=log_columns, sep='\\t')\n",
    "        cohort_log['group'] = group\n",
    "        cohort_logs.append(cohort_log)\n",
    "    # ignore_index gives a unique row index so the column operations below align row by row.\n",
    "    data = pd.concat(cohort_logs, axis=0, ignore_index=True)\n",
    "\n",
    "    scores = data['pred_score']\n",
    "    predict = pd.DataFrame({\n",
    "        'ID': data['fname'].str.replace('.png', '.gz', regex=False),\n",
    "        # Keep the score for the predicted label, use 1 - score for the other label.\n",
    "        'label-0': scores.where(data['pred_label'] == 0, 1 - scores),\n",
    "        'label-1': scores.where(data['pred_label'] == 1, 1 - scores),\n",
    "        'group': data['group'],\n",
    "    })\n",
    "    for group in ('train', 'test'):\n",
    "        predict.loc[predict['group'] == group, ['ID', 'label-0', 'label-1']].to_csv(\n",
    "            f'results/DTL_{model}_{group}.csv', index=False)\n",
    "    # A bare expression inside a loop is never rendered; display a preview explicitly.\n",
    "    display(predict.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "933925dc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
